import datetime
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import random
import time
from collections import deque
from itertools import count
import types
import pickle
import copy
import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from omegaconf import DictConfig, OmegaConf
from tensorboardX import SummaryWriter
from torch.autograd import Variable, grad

from logger import Logger
from make_envs import make_env
from memory import Memory
from agent import make_agent
from utils import eval_mode, get_concat_samples, evaluate, soft_update, hard_update

# Restrict torch's CPU intra-op parallelism to a single thread for this process.
torch.set_num_threads(1)
# NOTE(review): COUNTER is not referenced anywhere in the visible code — candidate for removal.
COUNTER = 0
def get_args(cfg: DictConfig):
    """Fill in run-time fields of the Hydra config, print it, and hand it back.

    Mutates ``cfg`` in place:
      * ``cfg.device`` becomes ``"cuda:0"`` when CUDA is available, else ``"cpu"``;
      * ``cfg.hydra_base_dir`` becomes the current working directory (under Hydra
        this is presumably the per-run output directory — confirm for your Hydra version).

    Returns the same ``cfg`` object.
    """
    if torch.cuda.is_available():
        chosen_device = "cuda:0"
    else:
        chosen_device = "cpu"
    cfg.device = chosen_device
    cfg.hydra_base_dir = os.getcwd()
    rendered = OmegaConf.to_yaml(cfg)
    print(rendered)
    return cfg


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Train an imitation / IRL agent from expert demonstrations.

    Hydra calls this with ``cfg`` built from ``conf/config``. The function:

    1. augments the config (``get_args``) and starts a wandb run;
    2. seeds Python, NumPy, torch and the train/eval environments;
    3. loads expert trajectories into a ``Memory`` buffer;
    4. steps the environment (uniform random actions for the first
       ``args.num_seed_steps`` environment steps, the agent's sampled actions
       afterwards) and performs one ``irl_update`` on an expert batch per
       environment step. Online transitions are not stored; updates use expert
       data only;
    5. evaluates every ``args.env.eval_interval`` learn steps, checkpoints via
       ``save`` and pickles the reward curves at the end of every episode.

    Returns once ``env.learn_steps`` learn steps have been taken.
    """
    args = get_args(cfg)
    # Optional `wandb_entity` config key selects the wandb user/team; when it is
    # absent, None is passed and wandb falls back to its default entity.
    wandb.init(project=args.project_name, entity=getattr(args, 'wandb_entity', None),
               sync_tensorboard=True, reinit=True, config=args)

    # set seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    env_args = args.env
    env = make_env(args)
    eval_env = make_env(args)
    # Seed envs (the eval env gets a different seed from the training env)
    env.seed(args.seed)
    eval_env.seed(args.seed + 10)

    REPLAY_MEMORY = int(env_args.replay_mem)
    EPISODE_STEPS = int(env_args.eps_steps)
    EPISODE_WINDOW = int(env_args.eps_window)
    LEARN_STEPS = int(env_args.learn_steps)

    agent = make_agent(env, args)
    # Bind the IRL update routines onto the agent once; they are loop-invariant.
    agent.irl_update = types.MethodType(irl_update, agent)
    agent.ilr_update_critic = types.MethodType(ilr_update_critic, agent)

    # The expert buffer is given half of the configured replay capacity.
    expert_memory_replay = Memory(REPLAY_MEMORY // 2, args.seed)
    expert_memory_replay.load(hydra.utils.to_absolute_path(f'experts/{args.env.demo}'),
                              num_trajs=args.eval.demos,
                              sample_freq=args.eval.subsample_freq,
                              seed=args.seed + 42)
    print(f'--> Expert memory size: {expert_memory_replay.size()}')

    ts_str = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = os.path.join(args.log_dir, args.env.name, args.exp_name, args.method.type, str(args.seed), ts_str)
    writer = SummaryWriter(log_dir=log_dir)
    print(f'--> Saving logs at: {log_dir}')
    # TODO: Fix logging
    logger = Logger(args.log_dir)

    steps = 0  # environment steps taken
    learn_steps = 0  # update steps taken (each evaluation also advances it by one)

    # track avg. reward and scores
    scores_window = deque(maxlen=EPISODE_WINDOW)  # last N scores
    rewards_window = deque(maxlen=EPISODE_WINDOW)  # last N rewards
    best_eval_returns = -np.inf

    eval_rewards = []  # mean return of each evaluation
    rewards = []  # windowed mean training return at the end of each episode
    ss = []  # environment step count at the end of each episode

    for epoch in count():
        state = env.reset()
        episode_reward = 0
        for _ in range(EPISODE_STEPS):
            if steps < args.num_seed_steps:
                action = env.action_space.sample()  # Sample random action
            else:
                with eval_mode(agent):
                    action = agent.choose_action(state, sample=True)
            next_state, reward, done, _info = env.step(action)
            episode_reward += reward
            steps += 1

            if learn_steps % args.env.eval_interval == 0:
                eval_returns, eval_timesteps = evaluate(agent, eval_env, num_episodes=args.eval.eps)
                returns = np.mean(eval_returns)
                learn_steps += 1  # To prevent repeated eval at timestep 0
                writer.add_scalar('Rewards/eval_rewards', returns,
                                  global_step=learn_steps)
                eval_rewards.append(returns)
                print('EVAL\tEp {}\tAverage reward: {:.2f}\t'.format(epoch, returns))
                writer.add_scalar(
                    'Success/eval', np.mean((np.array(eval_returns) > 200)), global_step=epoch)

                if returns > best_eval_returns:
                    best_eval_returns = returns
                    wandb.run.summary["best_returns"] = best_eval_returns
                    save(agent, epoch, args, output_dir='results_best')

            learn_steps += 1
            # `>=` rather than `==`: the evaluation branch above also increments
            # learn_steps, so an exact-equality test can be stepped over and the
            # run would never terminate.
            if learn_steps >= LEARN_STEPS:
                print('Finished!')
                writer.close()  # flush pending TensorBoard events before finishing the wandb run
                wandb.finish()
                return

            # IRL update on a batch sampled from the expert buffer.
            losses = agent.irl_update(expert_memory_replay, logger, learn_steps)

            if learn_steps % args.log_interval == 0:
                for key, loss in losses.items():
                    writer.add_scalar(key, loss, global_step=learn_steps)

            if done:
                break
            state = next_state

        writer.add_scalar('episodes', epoch, global_step=learn_steps)

        rewards_window.append(episode_reward)
        scores_window.append(float(episode_reward > 200))
        writer.add_scalar('Rewards/train_reward', np.mean(rewards_window), global_step=epoch)
        writer.add_scalar('Success/train', np.mean(scores_window), global_step=epoch)

        print('TRAIN\tEp {}\tAverage reward: {:.2f}\t'.format(epoch, np.mean(rewards_window)))
        save(agent, epoch, args, output_dir='results')
        rewards.append(np.mean(rewards_window))
        ss.append(steps)

        _dump_results_pickle(args, env_args, rewards, eval_rewards, ss)


def _dump_results_pickle(args, env_args, rewards, eval_rewards, ss):
    """Pickle ``(rewards, eval_rewards, ss)`` to this run's results file, creating its directory if needed."""
    # NOTE(review): the path is relative to the working directory; with Hydra's
    # default run dir (outputs/<date>/<time>) "../../.." is presumably the launch
    # directory — confirm if the run-dir layout is changed.
    results_dir = f"../../../pickle_results/{args.method.type}"
    os.makedirs(results_dir, exist_ok=True)
    # NOTE(review): the "n_trajs" field records env.replay_mem (not eval.demos) and
    # "_lr_theta" repeats agent.critic_lr; kept unchanged so existing result files
    # keep their names — confirm which values were intended.
    file_name = (f"{args.env.name}{args.seed}n_trajs{env_args.replay_mem}"
                 f"_lr_w{args.agent.critic_lr}_lr_theta{args.agent.critic_lr}.pt")
    with open(f"{results_dir}/{file_name}", "wb") as f:
        print("Saving Pickle")
        pickle.dump((rewards, eval_rewards, ss), f)

def save(agent, epoch, args, output_dir='results'):
    """Checkpoint ``agent`` every ``args.save_interval`` epochs.

    The checkpoint path is ``{output_dir}/{args.agent.name}_{prefix}_{args.env.name}``,
    where ``prefix`` is ``sqil`` for SQIL runs and ``iq`` for every other method
    type. ``output_dir`` (including any missing parent directories) is created
    if it does not exist. Epochs that are not a multiple of the interval are a no-op.
    """
    if epoch % args.save_interval != 0:
        return
    prefix = 'sqil' if args.method.type == "sqil" else 'iq'
    name = f'{prefix}_{args.env.name}'
    # makedirs(exist_ok=True) handles nested paths and avoids the exists/mkdir race.
    os.makedirs(output_dir, exist_ok=True)
    agent.save(f'{output_dir}/{args.agent.name}_{name}')

def ilr_update_critic(self, expert_batch, logger, step):
    """Perform one critic optimisation step on a batch of expert transitions.

    Meant to be bound to an agent with ``types.MethodType``. ``self`` must provide
    ``args``, ``gamma``, ``critic``, ``critic_optimizer`` and ``getV``, plus,
    depending on the method, ``get_targetV`` / ``choose_action`` + ``device`` /
    ``get_reward``.

    Losses by ``self.args.method.type``:

    * ``"iq"``: ``-mean(phi_grad * r)`` with ``r = Q(s, a) - (1 - done) * gamma * V(s')``
      (``phi_grad`` chosen by ``args.method.div`` and computed without gradient),
      plus ``(1 - gamma) * mean V(s)`` when ``args.method.loss == "v0"`` or
      ``mean(V(s) - (1 - done) * gamma * V(s'))`` when it is ``"value_expert"``.
    * ``"spoil"``: ``-mean Q(s, a_expert) + mean Q(s, a_learner)`` with learner
      actions sampled from ``self.choose_action``.
    * ``"ppil"`` (only ``args.method.loss == "value"``): ``-mean R(s, a)`` +
      ``logsumexp(10 * (R - Q + y)) / 10`` over the batch +
      ``mean(V(s) - y)``, where ``y = (1 - done) * gamma * V(s')``.

    Args:
        expert_batch: ``(obs, next_obs, action, reward, done)`` as returned by the
            expert memory's ``get_samples`` — presumably tensors on ``self.device``.
            The reward entry is not used.
        logger: unused here; kept so the signature matches ``irl_update``.
        step: unused here.

    Returns:
        dict mapping loss names to Python floats; ``'v0'`` and ``'total_loss'``
        are always present.

    Raises:
        ValueError: unknown method type, or a ppil loss other than ``"value"``.
    """
    args = self.args
    expert_obs, expert_next_obs, expert_action, expert_reward, done = expert_batch

    losses = {}
    # keep track of v0
    v0 = self.getV(expert_obs).mean()
    losses['v0'] = v0.item()

    if args.method.type == "iq":
        # our method, calculate 1st term of loss
        #  -E_(ρ_expert)[Q(s, a) - γV(s')]
        current_Q = self.critic(expert_obs, expert_action)
        next_v = self.getV(expert_next_obs)
        y = (1 - done) * self.gamma * next_v

        if args.train.use_target:
            with torch.no_grad():
                next_v = self.get_targetV(expert_next_obs)
                y = (1 - done) * self.gamma * next_v

        reward = (current_Q - y)

        # Per-sample weight selected by the divergence; computed without gradient
        # so it acts as a constant multiplier on the implicit reward.
        with torch.no_grad():
            if args.method.div == "hellinger":
                phi_grad = 1/(1+reward)**2
            elif args.method.div == "kl":
                phi_grad = torch.exp(-reward-1)
            elif args.method.div == "kl2":
                phi_grad = F.softmax(-reward, dim=0) * reward.shape[0]
            elif args.method.div == "kl_fix":
                phi_grad = torch.exp(-reward)
            elif args.method.div == "js":
                phi_grad = torch.exp(-reward)/(2 - torch.exp(-reward))
            else:
                phi_grad = 1
        loss = -(phi_grad * reward).mean()
        losses['softq_loss'] = loss.item()

        if args.method.loss == "v0":
            # calculate 2nd term for our loss
            # (1-γ)E_(ρ0)[V(s0)]
            v0_loss = (1 - self.gamma) * v0
            loss += v0_loss
            losses['v0_loss'] = v0_loss.item()

        elif args.method.loss == "value_expert":
            # alternative 2nd term for our loss (use only expert states)
            # E_(ρ)[V(s) - γV(s')]
            value_loss = (self.getV(expert_obs) - y).mean()
            loss += value_loss
            losses['value_loss'] = value_loss.item()

    elif args.method.type == "spoil":
        Q1 = self.critic(expert_obs, expert_action)
        # One sampled learner action per expert observation, each wrapped in a list
        # to give the rows a trailing action dimension. torch.FloatTensor allocates
        # on CPU, so move the result to the agent's device before the critic call.
        learner_actions = torch.FloatTensor(
            [[self.choose_action(o, sample=True)] for o in expert_obs]
        ).to(self.device)
        Q2 = self.critic(expert_obs, learner_actions)
        loss = - Q1.mean() + Q2.mean()
        losses["value"] = loss.item()

    elif args.method.type == "ppil":
        if args.method.loss != "value":
            raise ValueError(f'This loss is not implemented for ppil: {args.method.loss}')
        current_Q = self.critic(expert_obs, expert_action)
        current_r = self.get_reward(expert_obs, expert_action)
        next_v = self.getV(expert_next_obs)
        y = (1 - done) * self.gamma * next_v
        # First term: -E[R(s, a)]
        loss = - (current_r).mean()
        # Second term: smooth maximum over the batch (logsumexp scaled by 10).
        value_loss = torch.logsumexp(10*(-current_Q + current_r + y), dim=0)/10
        # Presumably Q/R are (batch, 1), so value_loss has shape (1,) — TODO confirm.
        loss += value_loss[0]
        # Third term: E[V(s) - γV(s')] on expert states.
        value_loss = (self.getV(expert_obs) - y).mean()
        loss += value_loss
        losses['value'] = loss.item()
    else:
        raise ValueError(f'This method is not implemented: {args.method.type}')

    losses['total_loss'] = loss.item()
    # NOTE(review): "logistic_offline" can never reach this point (unknown method
    # types raise above), so only the else-branch runs. ppil's loss also depends on
    # self.get_reward; if the reward model has its own optimizer, it is never
    # stepped here — confirm whether this branch was meant for "ppil".
    if args.method.type == "logistic_offline":
        # Optimize the critic and the reward
        self.critic_optimizer.zero_grad()
        self.reward_optimizer.zero_grad()
        loss.backward()
        # step critic
        self.critic_optimizer.step()
        self.reward_optimizer.step()
    else:
        # Optimize the critic
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()
    return losses


def irl_update(self, expert_buffer, logger, step):
    """Draw one expert batch and hand it to ``self.ilr_update_critic``.

    The batch is sampled with ``expert_buffer.get_samples(self.batch_size, self.device)``;
    ``logger`` and ``step`` are forwarded unchanged. Returns whatever
    ``self.ilr_update_critic`` returns (in this file, a dict of scalar losses).
    """
    sampled = expert_buffer.get_samples(self.batch_size, self.device)
    return self.ilr_update_critic(sampled, logger, step)


if __name__ == "__main__":
    # main() takes no explicit argument here: the @hydra.main decorator composes
    # `cfg` from conf/config (plus command-line overrides) and passes it in.
    main()
